
[PT2E] Fix per-tensor observer issue with varying shape & rank #2177

Open · Xia-Weiwen wants to merge 3 commits into main from fix_per_tensor_quant

Conversation

@Xia-Weiwen (Collaborator)

Fixes #2094 and #2112
Inputs may have varying shapes and ranks, e.g. when running Resnet18. The current implementation relies on a fixed block_size, which is not enough for such cases. The fix is simple: use block_size = -1 for each dimension for per-tensor quantization, and update block_size for each input when inserting q/dq in convert.
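A minimal standalone sketch of this idea (an editorial illustration, not the PR's actual code; `resolve_block_size` is a hypothetical name, whereas the PR calls torchao's get_block_size inside convert):

```python
# A minimal standalone sketch of the fix described above, not the PR's actual
# code. `resolve_block_size` is a hypothetical name; the PR itself calls
# torchao's get_block_size inside convert().
from typing import Sequence, Tuple


def resolve_block_size(
    block_size: Sequence[int], input_shape: Sequence[int]
) -> Tuple[int, ...]:
    """Turn an all -1 per-tensor placeholder into a concrete block size."""
    if all(b == -1 for b in block_size):
        # Per-tensor quantization: one block covers the whole tensor,
        # whatever rank the current input happens to have.
        return tuple(input_shape)
    # Other granularities: the stored block size must match the input rank.
    assert len(block_size) == len(input_shape)
    return tuple(d if b == -1 else b for b, d in zip(block_size, input_shape))


# Observer calibrated on a 4-D activation ...
print(resolve_block_size((-1, -1, -1, -1), (1, 64, 56, 56)))  # (1, 64, 56, 56)
# ... later sees a 2-D input (rank changed, e.g. the input to Resnet18's fc).
print(resolve_block_size((-1, -1, -1, -1), (1, 512)))         # (1, 512)
```

With the placeholder, the observer does not need to assume a fixed rank at calibration time; the concrete shape is only needed when the q/dq nodes are inserted in convert.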


pytorch-bot bot commented May 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2177

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b03b1e6 with merge base 07ca637:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label May 6, 2025
@Xia-Weiwen added the topic: not user facing label May 6, 2025
@Xia-Weiwen changed the title from "[PT2E] Fix per-tensor observer issue with varing shape & rank" to "[PT2E] Fix per-tensor observer issue with varying shape & rank" May 6, 2025
@Xia-Weiwen force-pushed the fix_per_tensor_quant branch from 87f1249 to 2ac41fb May 6, 2025 12:12
@@ -1891,6 +1891,10 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
assert self.original_dtype is not None, (
"Expecting original_dtype to be populated"
)
# Since input shape & rank may change (e.g. Resnet18), here we need to update block_size for each input
self.block_size = get_block_size(
@jerryzh168 (Contributor) May 6, 2025

when does this happen? can you give an example? I thought using -1 for dynamic dimensions would be enough?

@Xia-Weiwen (Collaborator, Author)

To reproduce the issue, you may run the code here: #2094 (comment)
You will have to use -1 for block_size without updating self.block_size here.

Contributor

are you saying the rank / number of dimensions changes for the input as well? can we use a single -1 to represent this case?

@Xia-Weiwen (Collaborator, Author)

> are you saying the rank / number of dimensions changes for the input as well?

Yes.

> can we use a single -1 to represent this case?

I think it's doable, but there are checks that guard len(self.block_size) == len(input.shape). We would need to handle the special case for per-tensor quant at those locations. Is that OK?
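As an editorial aside, a minimal sketch of the special-casing being discussed, assuming a bare (-1,) sentinel were used for per-tensor quantization (hypothetical; the PR instead stores -1 per dimension):

```python
# Hypothetical illustration only: if a single (-1,) sentinel denoted
# per-tensor quantization, every check of the form
# len(self.block_size) == len(input.shape) would need a carve-out like this.
def resolve_or_check(block_size, input_shape):
    if tuple(block_size) == (-1,):
        # Per-tensor sentinel: one block spans the whole tensor.
        return tuple(input_shape)
    assert len(block_size) == len(input_shape), (
        f"block_size {block_size} does not match input shape {input_shape}"
    )
    return tuple(block_size)


print(resolve_or_check((-1,), (1, 64, 56, 56)))           # (1, 64, 56, 56)
print(resolve_or_check((1, 64, 1, 1), (1, 64, 56, 56)))   # (1, 64, 1, 1)
```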

@Xia-Weiwen requested a review from jerryzh168 May 8, 2025 07:15
@Xia-Weiwen (Collaborator, Author)

@jerryzh168 Could you please review this PR? Thanks.

@Xia-Weiwen requested a review from drisspg May 13, 2025 01:28
@Xia-Weiwen marked this pull request as ready for review May 13, 2025 01:28
@Xia-Weiwen (Collaborator, Author)

Hi @jerryzh168 @drisspg, could you please review this PR? I am not sure whether the current implementation is what you expected. Thanks.

@@ -113,7 +113,8 @@ def _get_reduction_params(block_size, input_size):
shape_for_reduction: (3, 3, 5, 2, 10)
reduction_dim: [0, 1, 3, 4]
"""
assert len(block_size) == len(input_size)
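For reference, a simplified paraphrase of what _get_reduction_params computes (an editorial sketch based on the docstring output shown above, not the exact torchao implementation). The output above is consistent with block_size (3, 3, 2, 10) and input_size (3, 3, 10, 10): each dimension is split into (dim // block, block) and the block axes are reduced, which is why the assert requires block_size and input_size to have the same rank.

```python
# Simplified paraphrase of _get_reduction_params, based only on the docstring
# output above; the real torchao implementation may differ in details.
def reduction_params(block_size, input_size):
    assert len(block_size) == len(input_size)  # the guard under discussion
    shape_for_reduction, reduction_dims = [], []
    for blk, dim in zip(block_size, input_size):
        if blk == dim:
            # One block spans the whole dimension: keep it, reduce over it.
            shape_for_reduction.append(dim)
            reduction_dims.append(len(shape_for_reduction) - 1)
        elif blk == 1:
            # Every element is its own block: keep the dimension, no reduction.
            shape_for_reduction.append(dim)
        else:
            # Split into (dim // blk) blocks of size blk; reduce the block axis.
            shape_for_reduction.extend([dim // blk, blk])
            reduction_dims.append(len(shape_for_reduction) - 1)
    return shape_for_reduction, reduction_dims


# Matches the docstring output: block_size (3, 3, 2, 10), input (3, 3, 10, 10)
print(reduction_params((3, 3, 2, 10), (3, 3, 10, 10)))
# ([3, 3, 5, 2, 10], [0, 1, 3, 4])

# Per-tensor case: block_size == input_size, so every dimension is reduced.
print(reduction_params((1, 64, 56, 56), (1, 64, 56, 56)))
# ([1, 64, 56, 56], [0, 1, 2, 3])
```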
Contributor

is this still used? we should be using the code in quant_primitives.py I think

@Xia-Weiwen (Collaborator, Author)

Yes. It's still used when running the prepared model (model after prepare_pt2e). Is it a bug? Do I need to fix it, too?

@Xia-Weiwen (Collaborator, Author)

I am using the observers defined here: torchao/quantization/pt2e/_affine_quantization.py

@Xia-Weiwen (Collaborator, Author)

Hi @jerryzh168 May I know your suggestion on this? Thanks.

Contributor

I think we should be using the ones in torchao/quantization/observer.py eventually

The only occurrence seems to be AffineQuantizedMinMaxObserver, and we want to update it I think.

So if you are adding new things I'd recommend using the one from torchao.quantization.

Contributor

@Xia-Weiwen sorry for the delay, please feel free to work on this

Contributor

I thought we already use the one from torchao, but if you saw we are using torch.ao please go ahead and change them.

@Xia-Weiwen (Collaborator, Author)

@jerryzh168 Thanks for the reply. I did not mean torch.ao. I meant there are two versions of such utilities in torchao: torchao.quantization.pt2e and torchao.quantization. For example, there are both PartialWrapper and _PartialWrapper classes.
The PT2E flow in torchao uses the ones in torchao.quantization.pt2e, while you said you wanted to switch to torchao/quantization/observer.py.
So I was asking whether you would switch the PT2E flow to torchao/quantization/observer.py first. Do you have any suggestions on that? Thanks.

Contributor

oh I see, yeah for now using torchao/quantization/observer.py would be better I think, we haven't finalized the folder structure for this one yet

@Xia-Weiwen (Collaborator, Author)

@jerryzh168 Am I supposed to wait until you finalize the folder structure? Thanks.

@jerryzh168 (Contributor)

would be good to add a test for this

Labels: CLA Signed, topic: not user facing
Projects: None yet
Development

Successfully merging this pull request may close these issues.

[Quant][PT2E] AffineQuantized observers failed Resnet18
3 participants